# -*- coding: utf-8 -*-
# Copyright 2025 BrainX Ecosystem Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import functools
from typing import Dict, Sequence, Union, Callable

import jax
import numpy as np
import tqdm.auto
from jax.tree_util import tree_map

import brainstate.environ
from brainpy.mixin import SupportOnline
from brainpy import math as bm, tools
from brainpy.algorithms.online import get, OnlineAlgorithm, RLS
from brainpy.context import share
from brainpy.dynsys import DynamicalSystem
from brainpy.helpers import clear_input
from brainpy.runners import _call_fun_with_share
from brainpy.types import ArrayType, Output
from ._utils import format_ys
from .base import DSTrainer

# Public API of this module: the names exported by a star-import.
__all__ = [
    'OnlineTrainer',
    'ForceTrainer',
]


class OnlineTrainer(DSTrainer):
    """Online trainer for models with recurrent dynamics.

    On construction the trainer collects every node of ``target`` whose
    ``mode`` is an instance of ``bm.TrainingMode``, assigns the fitting method
    to each of them (``node.online_fit_by``) and calls ``node.online_init()``.
    During :py:meth:`fit` each of these nodes is updated at every time step
    through ``node.online_fit(target, fit_record)``.

    For more parameters, users should refer to :py:class:`~.DSRunner`.

    Parameters::

    target: DynamicalSystem
      The target model to train.

    fit_method: OnlineAlgorithm, Callable, dict, str
      The fitting method applied to the target model.

      - It can be ``None`` (default), which means ``RLS(alpha=1e-7)``.
      - It can be a string, which specify the shortcut name of the training algorithm.
        Like, ``fit_method='rls'`` means using the RLS method.
        All supported fitting methods can be obtained through
        :py:func:`~.get_supported_online_methods`.
      - It can be a dict, whose "name" item specifies the name of the training algorithm,
        and the others parameters specify the initialization parameters of the algorithm.
        For example, ``fit_method={'name': 'rls', 'alpha': 0.1}``.
        The dict is copied and never modified by the trainer.
      - It can be an instance of :py:class:`brainpy.algorithms.OnlineAlgorithm`.
        For example, ``fit_method=bp.algorithms.RLS(alpha=1e-5)``.
      - It can also be a callable function.

    kwargs: Any
      Other general parameters please see :py:class:`~.DSRunner`.

    Raises::

    ValueError
      If the target has no trainable node, or if the resolved ``fit_method``
      is not callable.
    TypeError
      If a trainable node is not an instance of ``SupportOnline``.
    """

    def __init__(
        self,
        target: DynamicalSystem,
        fit_method: Union[OnlineAlgorithm, Callable, Dict, str] = None,
        **kwargs
    ):
        super().__init__(target=target, **kwargs)

        # get all trainable nodes
        nodes = self.target.nodes(level=-1, include_self=True).subset(DynamicalSystem).unique()
        self.train_nodes = tuple(node for node in nodes.values() if isinstance(node.mode, bm.TrainingMode))
        if len(self.train_nodes) == 0:
            raise ValueError('Found no trainable nodes.')

        # check the required interface in the trainable nodes
        self._check_interface()

        # training method
        if fit_method is None:
            fit_method = RLS(alpha=1e-7)
        elif isinstance(fit_method, str):
            fit_method = get(fit_method)()
        elif isinstance(fit_method, dict):
            # Work on a copy: popping "name" from the caller's dict would
            # silently break any later reuse of that dict (e.g. for building
            # a second trainer with the same configuration).
            params = dict(fit_method)
            name = params.pop('name')
            fit_method = get(name)(**params)
        if not callable(fit_method):
            raise ValueError(f'"fit_method" must be an instance of callable function, '
                             f'but we got {type(fit_method)}.')
        self.fit_method = fit_method

        # set the training method
        for node in self.train_nodes:
            node.online_fit_by = fit_method

        # initialize the fitting method
        for node in self.train_nodes:
            node.online_init()

        # training function
        self._f_fit_compiled = dict()

    def __repr__(self):
        name = self.__class__.__name__
        indent = ' ' * len(name)
        indent2 = indent + " " * len("target")
        return (f'{name}(target={tools.repr_context(str(self.target), indent2)}, \n'
                f'{indent}jit={self.jit}, \n'
                f'{indent}fit_method={self.fit_method})')

    def predict(
        self,
        inputs: Union[ArrayType, Sequence[ArrayType], Dict[str, ArrayType]],
        reset_state: bool = False,
        shared_args: Dict = None,
        eval_time: bool = False
    ) -> Output:
        """Prediction function.

        Delegates to the parent ``predict()`` and afterwards clears the
        ``fit_record`` of every trainable node.

        Parameters::

        inputs: ArrayType
          The input values.
        reset_state: bool
          Reset the target state before running.
        shared_args: dict
          The shared arguments across nodes.
        eval_time: bool
          Whether we evaluate the running time or not?

        Returns::

        output: ArrayType
          The running output.
        """
        outs = super().predict(inputs=inputs,
                               reset_state=reset_state,
                               shared_args=shared_args,
                               eval_time=eval_time)
        for node in self.train_nodes:
            node.fit_record.clear()
        return outs

    def fit(
        self,
        train_data: Sequence,
        reset_state: bool = False,
        shared_args: Dict = None,
    ) -> Output:
        """Fit the trainable nodes of the target model online.

        Parameters::

        train_data: list, tuple
          The training data, a ``(X, Y)`` pair. ``Y`` is passed through
          ``format_ys`` and is later indexed by the name of each trainable node.
        reset_state: bool
          Whether to reset the target state (using the batch size inferred
          from ``X``) and the trainer state before fitting.
        shared_args: dict
          The shared arguments across nodes. The entry ``'fit'`` defaults to
          ``True`` when it is not given. The passed dict is not modified.

        Returns::

        output: ArrayType
          The model outputs collected over all time steps.

        Raises::

        ValueError
          If ``train_data`` is not a list/tuple of length 2.
        """
        with brainstate.environ.context(fit=True):
            # Copy so that adding the 'fit' flag never leaks into the caller's dict.
            shared_args = dict() if shared_args is None else dict(shared_args)
            shared_args.setdefault('fit', True)
            shared_args = tools.DotDict(shared_args)

            # checking training and testing data
            if not isinstance(train_data, (list, tuple)):
                raise ValueError(f"{self.__class__.__name__} only support "
                                 f"training data with the format of (X, Y) pair, "
                                 f"but we got a {type(train_data)}.")
            if len(train_data) != 2:
                raise ValueError(f"{self.__class__.__name__} only support "
                                 f"training data with the format of (X, Y) pair, "
                                 f"but we got a sequence with length {len(train_data)}")
            xs, ys = train_data

            # reset the model states
            if reset_state:
                num_batch = self._get_input_batch_size(xs)
                self.target.reset(num_batch)
                self.reset_state()

            # format input/target data
            ys = format_ys(self, ys)
            num_step = self._get_input_time_step(xs=xs)

            # running indices continue from where the previous run stopped
            indices = np.arange(self.i0, num_step + self.i0, dtype=np.int_)
            if self.data_first_axis == 'B':
                # the loop iterates over the leading axis, so bring axis 1 to the front
                xs = tree_map(lambda x: bm.moveaxis(x, 0, 1),
                              xs,
                              is_leaf=lambda x: isinstance(x, bm.Array))
                ys = tree_map(lambda y: bm.moveaxis(y, 0, 1),
                              ys,
                              is_leaf=lambda y: isinstance(y, bm.Array))

            # init monitor
            for key in self._monitors.keys():
                self.mon[key] = []  # reshape the monitor items

            # init progress bar
            if self.progress_bar:
                self._pbar = tqdm.auto.tqdm(total=num_step)
                self._pbar.set_description(f"Train {num_step} steps: ", refresh=True)

            # fitting; the progress bar is closed even when the loop fails
            xs = (xs,) if not isinstance(xs, (tuple, list)) else xs
            try:
                outs, hists = self._fit(indices, xs=xs, ys=ys, shared_args=shared_args)
            finally:
                if self.progress_bar:
                    self._pbar.close()

            # post-running for monitors
            if self.numpy_mon_after_run:
                hists = tree_map(lambda a: np.asarray(a), hists, is_leaf=lambda a: isinstance(a, bm.Array))
            for key in hists.keys():
                self.mon[key] = hists[key]
            # advance the index offset used by the next run
            self.i0 += num_step
            return outs

    def _fit(self,
             indices: ArrayType,
             xs: Sequence,
             ys: Dict[str, ArrayType],
             shared_args: Dict = None):
        """Run the fitting step over all running indices.

        Parameters::

        indices: ArrayType
          The running indices.
        xs: sequence
          The inputs, unpacked as positional arguments of the target at each step.
        ys: dict
          Each tensor should have the shape of `(num_time, num_batch, num_feature)`.
        shared_args: optional, dict
          The shared keyword arguments.

        Returns::

        outputs, hists
          A tuple of pair of (outputs, hists), i.e. the stacked results of
          :py:meth:`_step_func_fit`, with axes 0 and 1 of every leaf swapped.
        """
        results = bm.for_loop(functools.partial(self._step_func_fit, shared_args=shared_args),
                              (indices, xs, ys),
                              jit=self.jit['fit'])
        # NOTE(review): the axes are swapped unconditionally here, whereas `fit`
        # only swaps the inputs when `data_first_axis == 'B'` -- confirm this is
        # intended for other values of `data_first_axis`.
        results = tree_map(lambda x: bm.moveaxis(x, 0, 1),
                           results,
                           is_leaf=lambda x: isinstance(x, bm.Array))
        return results

    def _step_func_fit(self, i, xs: Sequence, ys: Dict, shared_args=None):
        """Run one time step: update the target, then fit every trainable node.

        Parameters::

        i: int
          The running index of the current step.
        xs: sequence
          The inputs of the current step, passed to the target as ``*xs``.
        ys: dict
          The targets of the current step, indexed by the trainable node name.
        shared_args: optional, dict
          The shared keyword arguments saved into ``share``.

        Returns::

        out, monitors
          The target output and the monitored values (without fit records).
        """
        if shared_args is None:
            shared_args = dict()
        share.save(t=i * self.dt, dt=self.dt, i=i, **shared_args)

        # input step
        clear_input(self.target)
        self._step_func_input()

        # update step
        out = self.target(*xs)

        # monitor step
        monitors = self._step_func_monitor()
        for node in self.train_nodes:
            # fit records are consumed here so they do not end up in the monitors
            fit_record = monitors.pop(f'{node.name}-fit_record')
            target = ys[node.name]
            node.online_fit(target, fit_record)

        # finally
        if self.progress_bar:
            jax.debug.callback(lambda *args: self._pbar.update(), ())
        return out, monitors

    def _check_interface(self):
        """Check that every trainable node is an instance of ``SupportOnline``.

        Raises::

        TypeError
          If a trainable node does not satisfy the required interface.
        """
        for node in self.train_nodes:
            if not isinstance(node, SupportOnline):
                raise TypeError(
                    f'The node \n\n{node}\n\n'
                    f'is set to be trainable with {self.__class__.__name__} method. \n'
                    f'{self.__class__.__name__} only support training nodes that are instances '
                    f'of {SupportOnline}. '
                )

    def _step_func_monitor(self):
        """Collect the monitored values of the current step.

        Returns::

        res: dict
          The monitored values by monitor key. When the shared ``'fit'`` flag
          is truthy, the ``fit_record`` of every trainable node is added under
          the key ``'<node name>-fit_record'``.
        """
        res = dict()
        for key, val in self._monitors.items():
            if callable(val):
                res[key] = _call_fun_with_share(val)
            else:
                (variable, idx) = val
                if idx is None:
                    res[key] = variable.value
                else:
                    res[key] = variable[bm.asarray(idx)]
        if share.load('fit'):
            for node in self.train_nodes:
                res[f'{node.name}-fit_record'] = node.fit_record
        return res


class ForceTrainer(OnlineTrainer):
    """FORCE learning.

    An :py:class:`OnlineTrainer` whose fitting method is fixed to
    ``RLS(alpha=alpha)``.

    Parameters::

    target: DynamicalSystem
      The target model to train.
    alpha: float
      Passed unchanged as the ``alpha`` argument of ``RLS``.
    kwargs: Any
      Remaining keyword arguments, forwarded to :py:class:`OnlineTrainer`.
    """

    def __init__(self, target, alpha=1., **kwargs):
        rls = RLS(alpha=alpha)
        super().__init__(target=target, fit_method=rls, **kwargs)
